# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


import os
import random
import string
import subprocess
import tempfile

from ..base_constraint_satisfaction_generator import (
    BaseConstraintSatisfactionTestGenerator,
)


class CGenerator(BaseConstraintSatisfactionTestGenerator):
    """Constraint-satisfaction test generator whose challenge is a C program.

    Each test case is a self-contained C source file that checks a
    command-line string against constraints computed by
    ``calculate_constraints`` from a random seed, and prints the fraction of
    checks passed (``1.00`` means every constraint is satisfied).
    """

    def __init__(self) -> None:
        super().__init__()

    def generate_prompt(self, code: str) -> str:
        """Build the model-facing prompt asking for an input that scores 1.0.

        Args:
            code: The C source produced by ``generate_test_case``.

        Returns:
            The instruction text with ``code`` appended.
        """
        prompt = f"""Here's a c code that takes a string and prints out a score between 0 and 1.0. Can you find the string that would output a score of 1.0? Respond in json format with "answer" as the key. Only return the input string without explanation. \n\ncode: {code}"""

        return prompt

    def generate_test_case(self) -> tuple[str, bytes, str]:
        """Create one randomized C scoring program.

        A 20-character alphanumeric seed is drawn and handed to
        ``calculate_constraints``, which populates the ``self.*`` attributes
        interpolated into the C source below (length, character-class flags
        and counts, palindrome flag, case ratio, ...).

        Returns:
            A ``(language, seed, source)`` tuple: the tag ``"c"``, the seed
            encoded as bytes (presumably a known input satisfying the
            constraints — confirm against ``calculate_constraints``), and
            the generated C source text.
        """
        random_seed = "".join(
            random.choices(string.ascii_letters + string.digits, k=20)
        )
        self.calculate_constraints(random_seed)

        # NOTE: this C source is benchmark content shown to the model; keep it
        # stable. Doubled braces are f-string escapes for literal C braces, and
        # "\\n" emits a C-level "\n" escape.
        c_code = f"""
#include <stdio.h>
#include <string.h>
#include <ctype.h>

int is_prime(int n) {{
    if (n <= 1 || (n % 2 == 0 && n > 2)) {{
        return 0;
    }}
    for(int i = 3; i * i <= n; i += 2) {{
        if (n % i == 0) {{
            return 0;
        }}
    }}
    return 1;
}}

int count_vowels(char *str) {{
    int count = 0;
    for (int i = 0; str[i] != '\\0'; i++) {{
        if (isalpha(str[i])) {{
            if (tolower(str[i]) == 'a' || tolower(str[i]) == 'e' || tolower(str[i]) == 'i' || tolower(str[i]) == 'o' || tolower(str[i]) == 'u') {{
                count++;
            }}
        }}
    }}
    return count;
}}
int count_consonants(char *str) {{
    int count = 0;
    for (int i = 0; str[i] != '\\0'; i++) {{
        if (isalpha(str[i])) {{
            if (!(tolower(str[i]) == 'a' || tolower(str[i]) == 'e' || tolower(str[i]) == 'i' || tolower(str[i]) == 'o' || tolower(str[i]) == 'u')) {{
                count++;
            }}
        }}
    }}
    return count;
}}
int count_numeric(char *str) {{
    int count = 0;
    for (int i = 0; str[i] != '\\0'; i++) {{
        if (isdigit(str[i])) {{
            count++;
        }}
    }}
    return count;
}}

int count_special(char *str) {{
    int count = 0;
    for (int i = 0; str[i] != '\\0'; i++) {{
        if (!isalnum(str[i])) {{
            count++;
        }}
    }}
    return count;
}}

char* strrev(char* str) {{
    char *p1, *p2;
    if (! str || ! *str)
        return str;
    for (p1 = str, p2 = str + strlen(str) - 1; p2 > p1; ++p1, --p2) {{
        *p1 ^= *p2;
        *p2 ^= *p1;
        *p1 ^= *p2;
    }}
    return str;
}}

int main(int argc, char *argv[]) {{
    if (argc != 2) {{
        printf("Usage: program <input_string>\\n");
        return 1;
    }}
    char *input_string = argv[1];
    int length = strlen(input_string);
    int length_is_prime = is_prime(length);
    int has_uppercase = 0, has_lowercase = 0, has_digit = 0, has_special = 0;
    int has_vowel = 0, has_consonant = 0, has_whitespace = 0;
    int unique_chars = 0, is_palindrome = 0;
    int has_alphabet = 0;
    int vowel_count = 0;
    int numeric_count = 0;
    int consonant_count = 0;
    int special_count = 0;
    int uppercase_count = 0;
    int lowercase_count = 0;
    float case_ratio = 0;
    for(int i = 0; i < length; i++) {{
        if (isupper(input_string[i])) has_uppercase = 1;
        if (islower(input_string[i])) has_lowercase = 1;
        if (isdigit(input_string[i])) has_digit = 1;
        if (!isalnum(input_string[i])) has_special = 1;
        if (strchr("aeiouAEIOU", input_string[i])) has_vowel = 1;
        if (strchr("bcdfghjklmnpqrstvwxyzBCDFGHJKLMNPQRSTVWXYZ", input_string[i])) has_consonant = 1;
        if (isspace(input_string[i])) has_whitespace = 1;
        if (isalpha(input_string[i])) has_alphabet = 1;
        if (isupper(input_string[i])) uppercase_count = uppercase_count + 1;
        if (islower(input_string[i])) lowercase_count = lowercase_count + 1;
    }}
    is_palindrome = strcmp(input_string, strrev(strdup(input_string))) == 0;
    case_ratio = has_lowercase != 0 ? (float)uppercase_count / lowercase_count : -1;
    int total_tests = {self.total_tests};
    int passed_tests = 0;
    numeric_count = count_numeric(input_string);
    consonant_count = count_consonants(input_string);
    vowel_count = count_vowels(input_string);
    special_count = count_special(input_string);

    for (int i = 0; i < length; i++) {{
        int flag = 0;
        for (int j = i + 1; j < length; j++) {{
            if (input_string[i] == input_string[j]) {{
                flag = 1;
                break;
            }}
        }}
        if (flag == 0) {{
            unique_chars++;
        }}
    }}

    passed_tests += length == {self.length};
    passed_tests += length_is_prime == {int(self.prime)};
    passed_tests += (has_uppercase == {int(self.has_uppercase)});
    passed_tests += (has_lowercase == {int(self.has_lowercase)});
    passed_tests += (has_digit == {int(self.has_digit)});
    passed_tests += (has_special == {int(self.has_special)});
    passed_tests += (has_vowel == {int(self.has_vowel)});
    passed_tests += (has_consonant == {int(self.has_consonant)});
    passed_tests += (has_whitespace == {int(self.has_whitespace)});
    passed_tests += unique_chars == {self.unique_chars};
    passed_tests += (is_palindrome == {int(self.is_palindrome)});
    passed_tests += (has_alphabet == {int(self.has_alphabet)});
    passed_tests += uppercase_count == {self.uppercase_count};
    passed_tests += lowercase_count == {self.lowercase_count};
    passed_tests += (case_ratio == (float){self.case_ratio});
    passed_tests += numeric_count >= {int(self.numeric_sequence)};
    passed_tests += numeric_count == {self.digit_count};
    passed_tests += consonant_count == {self.consonant_count};
    passed_tests += vowel_count == {self.vowel_count};
    passed_tests += special_count == {self.special_count};

    float score = (float)passed_tests / total_tests;

    printf("%.2f\\n", score);
    return 0;
}}
"""
        return ("c", random_seed.encode(), c_code)

    def compile_and_run(self, code: str, input: bytes) -> str:
        """Compile ``code`` with gcc and run it with ``input`` as argv[1].

        All build artifacts are placed in a fresh private temporary
        directory that is deleted before returning. This keeps concurrent
        calls from overwriting each other's source or executable, and it
        means nothing is left behind.

        Args:
            code: C source text to compile.
            input: Candidate string as bytes; decoded with the default UTF-8
                codec and passed as the program's only command-line argument.

        Returns:
            The program's stdout on success, gcc's stderr if compilation
            fails, or ``"0.0"`` if the program could not be launched (for
            example, the input is not valid UTF-8 or contains a NUL byte).
        """
        with tempfile.TemporaryDirectory(prefix="cgen_") as workdir:
            source_path = os.path.join(workdir, "test_case.c")
            binary_path = os.path.join(workdir, "test_case")

            with open(source_path, "w", encoding="utf-8") as f:
                f.write(code)

            compile_result = subprocess.run(
                ["gcc", "-o", binary_path, source_path],
                capture_output=True,
                text=True,
            )
            # Compiler diagnostics are surfaced to the caller verbatim.
            if compile_result.returncode != 0:
                return compile_result.stderr

            try:
                run_result = subprocess.run(
                    [binary_path, input.decode()], capture_output=True, text=True
                )
            except Exception:
                # Deliberate best-effort: an input that cannot be passed to or
                # run by the program scores zero instead of aborting the run.
                return "0.0"
            return run_result.stdout

    def score_from_output(self, output: str) -> float:
        """Parse the score printed by the generated program.

        Args:
            output: Text returned by ``compile_and_run``.

        Returns:
            The numeric score (the program prints it with ``%.2f``).

        Raises:
            ValueError: If ``output`` is not a number, e.g. when
                ``compile_and_run`` returned compiler diagnostics.
        """
        return float(output)
